{
 "cells": [
  {
   "metadata": {
    "id": "eufgAL3xy6Zm",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "cell_type": "markdown",
   "source": [
    "# Train Image Similarity"
   ]
  },
  {
   "metadata": {
    "id": "QqCQpjnQz9CI",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "cell_type": "markdown",
   "source": [
    "## Mount drive etc"
   ]
  },
  {
   "metadata": {
    "id": "StLnj_Q_y3vy",
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "!nvidia-smi"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "id": "_8hofRqTz-5J",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "cell_type": "markdown",
   "source": [
    "## Run the Training Script"
   ]
  },
  {
   "metadata": {
    "id": "bKUAY2Z790q_",
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "import torchvision.transforms as T\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n"
   ],
   "execution_count": 1,
   "outputs": []
  },
  {
   "metadata": {
    "id": "Vt3SI_V69yuG",
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "class FolderDataset(Dataset):\n",
    "    def __init__(self, main_dir, transform=None):\n",
    "        self.main_dir = main_dir\n",
    "        self.transform = transform\n",
    "        self.all_imgs = os.listdir(main_dir)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.all_imgs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img_loc = os.path.join(self.main_dir, self.all_imgs[idx])\n",
    "        image = Image.open(img_loc).convert(\"RGB\")\n",
    "\n",
    "        if self.transform is not None:\n",
    "            tensor_image = self.transform(image)\n",
    "\n",
    "        return tensor_image, tensor_image\n"
   ],
   "execution_count": 2,
   "outputs": []
  },
  {
   "metadata": {
    "id": "TfTyqEPL9ttJ",
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "\n",
    "class ConvEncoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # self.img_size = img_size\n",
    "        self.conv1 = nn.Conv2d(3, 16, (3, 3), padding=(1, 1))\n",
    "        self.relu1 = nn.ReLU(inplace=True)\n",
    "        self.maxpool1 = nn.MaxPool2d((2, 2))\n",
    "\n",
    "        self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=(1, 1))\n",
    "        self.relu2 = nn.ReLU(inplace=True)\n",
    "        self.maxpool2 = nn.MaxPool2d((2, 2))\n",
    "\n",
    "        self.conv3 = nn.Conv2d(32, 64, (3, 3), padding=(1, 1))\n",
    "        self.relu3 = nn.ReLU(inplace=True)\n",
    "        self.maxpool3 = nn.MaxPool2d((2, 2))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Downscale the image with conv maxpool etc.\n",
    "        # print(x.shape)\n",
    "        x = self.conv1(x)\n",
    "        x = self.relu1(x)\n",
    "        x = self.maxpool1(x)\n",
    "\n",
    "        # print(x.shape)\n",
    "\n",
    "        x = self.conv2(x)\n",
    "        x = self.relu2(x)\n",
    "        x = self.maxpool2(x)\n",
    "\n",
    "        # print(x.shape)\n",
    "\n",
    "        x = self.conv3(x)\n",
    "        x = self.relu3(x)\n",
    "        x = self.maxpool3(x)\n",
    "\n",
    "        # print(x.shape)\n",
    "        return x\n",
    "\n",
    "\n",
    "class ConvDecoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.deconv1 = nn.ConvTranspose2d(64, 32, (2, 2), stride=(2, 2))\n",
    "        # self.upsamp1 = nn.UpsamplingBilinear2d(2)\n",
    "        self.relu1 = nn.ReLU(inplace=True)\n",
    "\n",
    "        self.deconv2 = nn.ConvTranspose2d(32, 16, (2, 2), stride=(2, 2))\n",
    "#         self.upsamp1 = nn.UpsamplingBilinear2d(2)\n",
    "        self.relu2 = nn.ReLU(inplace=True)\n",
    "\n",
    "        self.deconv3 = nn.ConvTranspose2d(16, 3, (2, 2), stride=(2, 2))\n",
    "#         self.upsamp1 = nn.UpsamplingBilinear2d(2)\n",
    "        self.relu3 = nn.ReLU(inplace=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # print(x.shape)\n",
    "        x = self.deconv1(x)\n",
    "        x = self.relu1(x)\n",
    "        # print(x.shape)\n",
    "\n",
    "        x = self.deconv2(x)\n",
    "        x = self.relu2(x)\n",
    "        # print(x.shape)\n",
    "\n",
    "        x = self.deconv3(x)\n",
    "        x = self.relu3(x)\n",
    "        # print(x.shape)\n",
    "        return x\n"
   ],
   "execution_count": 7,
   "outputs": []
  },
  {
   "metadata": {
    "id": "IjLkAZy49Z7v",
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "IMG_PATH = \"../input/animals-data/dataset/\"\n",
    "IMG_HEIGHT = 512  # The images are already resized here\n",
    "IMG_WIDTH = 512  # The images are already resized here\n",
    "\n",
    "SEED = 42\n",
    "TRAIN_RATIO = 0.75\n",
    "VAL_RATIO = 1 - TRAIN_RATIO\n",
    "SHUFFLE_BUFFER_SIZE = 100\n",
    "\n",
    "LEARNING_RATE = 1e-3\n",
    "EPOCHS = 2\n",
    "TRAIN_BATCH_SIZE = 32  # Let's see, I don't have GPU, Google Colab is best hope\n",
    "TEST_BATCH_SIZE = 32  # Let's see, I don't have GPU, Google Colab is best hope\n",
    "FULL_BATCH_SIZE = 32\n",
    "\n",
    "AUTOENCODER_MODEL_PATH = \"baseline_autoencoder.pt\"\n",
    "ENCODER_MODEL_PATH = \"baseline_encoder.pt\"\n",
    "DECODER_MODEL_PATH = \"baseline_decoder.pt\"\n",
    "EMBEDDING_SHAPE = (1, 64, 64, 64)\n",
    "# TEST_RATIO = 0.2\n"
   ],
   "execution_count": 8,
   "outputs": []
  },
  {
   "metadata": {
    "id": "Vn7OetpD4gjZ",
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "\"\"\"\n",
    "I can write this if we need custom training loop etc.\n",
    "I usually use this in PyTorch.\n",
    "\"\"\"\n",
    "\n",
    "__all__ = [\"train_step\", \"val_step\", \"create_embedding\"]\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "def train_step(encoder, decoder, train_loader, loss_fn, optimizer, device):\n",
    "    # device = \"cuda\"\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "\n",
    "    # print(device)\n",
    "\n",
    "    for batch_idx, (train_img, target_img) in enumerate(train_loader):\n",
    "        train_img = train_img.to(device)\n",
    "        target_img = target_img.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        enc_output = encoder(train_img)\n",
    "        dec_output = decoder(enc_output)\n",
    "\n",
    "        loss = loss_fn(dec_output, target_img)\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "def val_step(encoder, decoder, val_loader, loss_fn, device):\n",
    "    encoder.eval()\n",
    "    decoder.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (train_img, target_img) in enumerate(val_loader):\n",
    "            train_img = train_img.to(device)\n",
    "            target_img = target_img.to(device)\n",
    "\n",
    "            enc_output = encoder(train_img)\n",
    "            dec_output = decoder(enc_output)\n",
    "\n",
    "            loss = loss_fn(dec_output, target_img)\n",
    "\n",
    "    return loss.item()"
   ],
   "execution_count": 9,
   "outputs": []
  },
  {
   "metadata": {
    "id": "DXPDCa8a9R3c",
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "# print(\"Setting Seed for the run, seed = {}\".format(config.SEED))\n",
    "\n",
    "# seed_everything(config.SEED)\n",
    "\n",
    "transforms = T.Compose([T.ToTensor()])\n",
    "print(\"------------ Creating Dataset ------------\")\n",
    "full_dataset = FolderDataset(IMG_PATH, transforms)\n",
    "\n",
    "train_size = int(TRAIN_RATIO * len(full_dataset))\n",
    "val_size = len(full_dataset) - train_size\n",
    "\n",
    "train_dataset, val_dataset = torch.utils.data.random_split(\n",
    "    full_dataset, [train_size, val_size]\n",
    ")\n",
    "\n",
    "print(\"------------ Dataset Created ------------\")\n",
    "print(\"------------ Creating DataLoader ------------\")\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, drop_last=True\n",
    ")\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    val_dataset, batch_size=TEST_BATCH_SIZE\n",
    ")\n",
    "\n",
    "full_loader = torch.utils.data.DataLoader(\n",
    "    full_dataset, batch_size=FULL_BATCH_SIZE\n",
    ")"
   ],
   "execution_count": 10,
   "outputs": [
    {
     "output_type": "stream",
     "text": "------------ Creating Dataset ------------\n------------ Dataset Created ------------\n------------ Creating DataLoader ------------\n",
     "name": "stdout"
    }
   ]
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "print(\"------------ Dataloader Cretead ------------\")\n",
    "\n",
    "# print(train_loader)\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "encoder = ConvEncoder()\n",
    "decoder = ConvDecoder()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    print(\"GPU Availaible moving models to GPU\")\n",
    "else:\n",
    "    print(\"Moving models to CPU\")\n",
    "\n",
    "encoder.to(device)\n",
    "decoder.to(device)\n",
    "\n",
    "# print(device)\n",
    "\n",
    "autoencoder_params = list(encoder.parameters()) + list(decoder.parameters())\n",
    "optimizer = optim.AdamW(autoencoder_params, lr=LEARNING_RATE)\n",
    "\n",
    "# early_stopper = utils.EarlyStopping(patience=5, verbose=True, path=)\n",
    "max_loss = 9999\n",
    "\n",
    "print(\"------------ Training started ------------\")\n",
    "\n",
    "for epoch in tqdm(range(EPOCHS)):\n",
    "    train_loss = train_step(\n",
    "        encoder, decoder, train_loader, loss_fn, optimizer, device=device\n",
    "    )\n",
    "    print(f\"Epochs = {epoch}, Training Loss : {train_loss}\")\n",
    "    val_loss = val_step(\n",
    "        encoder, decoder, val_loader, loss_fn, device=device\n",
    "    )\n",
    "\n",
    "    # Simple Best Model saving\n",
    "    if val_loss < max_loss:\n",
    "        print(\"Validation Loss decreased, saving new best model\")\n",
    "        torch.save(encoder.state_dict(), ENCODER_MODEL_PATH)\n",
    "        torch.save(decoder.state_dict(), DECODER_MODEL_PATH)\n",
    "\n",
    "    print(f\"Epochs = {epoch}, Validation Loss : {val_loss}\")\n",
    "\n",
    "print(\"Training Done\")"
   ],
   "execution_count": 11,
   "outputs": [
    {
     "output_type": "stream",
     "text": "------------ Dataloader Cretead ------------\nGPU Availaible moving models to GPU\n",
     "name": "stdout"
    },
    {
     "output_type": "stream",
     "text": "\r  0%|          | 0/2 [00:00<?, ?it/s]",
     "name": "stderr"
    },
    {
     "output_type": "stream",
     "text": "------------ Training started ------------\nEpochs = 0, Training Loss : 0.03702838718891144\n",
     "name": "stdout"
    },
    {
     "output_type": "stream",
     "text": "\r 50%|█████     | 1/2 [01:52<01:52, 112.76s/it]",
     "name": "stderr"
    },
    {
     "output_type": "stream",
     "text": "Validation Loss decreased, saving new best model\nEpochs = 0, Validation Loss : 0.034663040190935135\nEpochs = 1, Training Loss : 0.034218646585941315\n",
     "name": "stdout"
    },
    {
     "output_type": "stream",
     "text": "100%|██████████| 2/2 [03:35<00:00, 107.54s/it]",
     "name": "stderr"
    },
    {
     "output_type": "stream",
     "text": "Validation Loss decreased, saving new best model\nEpochs = 1, Validation Loss : 0.03623133897781372\nTraining Done\n",
     "name": "stdout"
    },
    {
     "output_type": "stream",
     "text": "\n",
     "name": "stderr"
    }
   ]
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "embedding_dim = (1, 64, 64, 64)"
   ],
   "execution_count": 13,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "def create_embedding(encoder, full_loader, embedding_dim, device):\n",
    "    encoder.eval()\n",
    "    embedding = torch.randn(embedding_dim)\n",
    "    # print(embedding.shape)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (train_img, target_img) in enumerate(full_loader):\n",
    "            train_img = train_img.to(device)\n",
    "            enc_output = encoder(train_img).cpu()\n",
    "            # print(enc_output.shape)\n",
    "            embedding = torch.cat((embedding, enc_output), 0)\n",
    "            # print(embedding.shape)\n",
    "    \n",
    "    return embedding\n"
   ],
   "execution_count": 15,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "embedding = create_embedding(encoder, full_loader, embedding_dim, device)"
   ],
   "execution_count": 2,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "print(embedding.shape)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "# embedding2 = embedding[4700:, :, :, ]"
   ],
   "execution_count": 18,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "# print(embedding2.shape)"
   ],
   "execution_count": 19,
   "outputs": [
    {
     "output_type": "stream",
     "text": "torch.Size([39, 64, 64, 64])\n",
     "name": "stdout"
    }
   ]
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "numpy_embedding = embedding.cpu().detach().numpy()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "# numpy_embedding = embedding2.cpu().detach().numpy()"
   ],
   "execution_count": 20,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "print(numpy_embedding.shape)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "num_images = numpy_embedding.shape[0]\n",
    "# print(num_images)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "flattened_embedding = numpy_embedding.reshape((num_images, -1))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "print(flattened_embedding.shape)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "import numpy as np"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "np.save(\"data_embedding_f.npy\", flattened_embedding)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "# np.save(\"../input/animals-data/data_embedding.npy\", flattened_embedding)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "flattend_embedding_reloaded = np.load(\"data_embedding.npy\")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [
    "# encoder.eval()\n",
    "# embedding = torch.randn(embedding_dim)\n",
    "# print(embedding.shape)\n",
    "\n",
    "# with torch.no_grad():\n",
    "#     for batch_idx, (train_img, target_img) in enumerate(full_loader):\n",
    "#         train_img = train_img.to(device)\n",
    "# #         print(train_img.shape)\n",
    "        \n",
    "#         enc_output = encoder(train_img).cpu()\n",
    "#         print(enc_output.shape)\n",
    "\n",
    "#         embedding = torch.cat((embedding, enc_output), 0)\n",
    "#         print(embedding.shape)\n",
    "\n",
    "\n",
    "# #         break\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "metadata": {
    "trusted": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "cell_type": "code",
   "source": [],
   "execution_count": null,
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3",
   "language": "python"
  },
  "language_info": {
   "name": "python",
   "version": "3.7.6",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}